{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c64d3134-4dec-44ff-9f5e-55b033c48edb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.manual_seed(0)\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset\n",
    "import time\n",
    "import glob\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from tqdm import tqdm_notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "01acdbde-f012-45db-ae94-beff97baa8f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_path = glob.glob('./train/*/*/*.jpg')\n",
    "np.random.shuffle(train_path)\n",
    "\n",
    "train_label = [int(x.split('/')[-3]) for x in train_path]\n",
    "\n",
    "test_path = glob.glob('./test/*.jpg')\n",
    "test_path.sort()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "cd793915-cb4d-46a1-8724-376d52c321f8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "22312"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8e422af9-c5c8-4196-8dc4-1435143017a8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0, 1, 2, 3, 4}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "set(train_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7c87b426-c326-4a3e-a3fe-3c69ccf2fa5b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['./test/001.jpg',\n",
       " './test/002.jpg',\n",
       " './test/003.jpg',\n",
       " './test/004.jpg',\n",
       " './test/005.jpg',\n",
       " './test/006.jpg',\n",
       " './test/007.jpg',\n",
       " './test/008.jpg',\n",
       " './test/009.jpg',\n",
       " './test/010.jpg']"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test_path[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ab48904c-3f33-4299-a6ab-c44ee9a40670",
   "metadata": {},
   "outputs": [],
   "source": [
    "class AverageMeter(object):\n",
    "    \"\"\"Computes and stores the average and current value\"\"\"\n",
    "    def __init__(self, name, fmt=':f'):\n",
    "        self.name = name\n",
    "        self.fmt = fmt\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        self.val = 0\n",
    "        self.avg = 0\n",
    "        self.sum = 0\n",
    "        self.count = 0\n",
    "\n",
    "    def update(self, val, n=1):\n",
    "        self.val = val\n",
    "        self.sum += val * n\n",
    "        self.count += n\n",
    "        self.avg = self.sum / self.count\n",
    "\n",
    "    def __str__(self):\n",
    "        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'\n",
    "        return fmtstr.format(**self.__dict__)\n",
    "\n",
    "class ProgressMeter(object):\n",
    "    def __init__(self, num_batches, *meters):\n",
    "        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)\n",
    "        self.meters = meters\n",
    "        self.prefix = \"\"\n",
    "\n",
    "\n",
    "    def pr2int(self, batch):\n",
    "        entries = [self.prefix + self.batch_fmtstr.format(batch)]\n",
    "        entries += [str(meter) for meter in self.meters]\n",
    "        print('\\t'.join(entries))\n",
    "\n",
    "    def _get_batch_fmtstr(self, num_batches):\n",
    "        num_digits = len(str(num_batches // 1))\n",
    "        fmt = '{:' + str(num_digits) + 'd}'\n",
    "        return '[' + fmt + '/' + fmt.format(num_batches) + ']'\n",
    "def validate(val_loader, model, criterion):\n",
    "    batch_time = AverageMeter('Time', ':6.3f')\n",
    "    losses = AverageMeter('Loss', ':.4e')\n",
    "    top1 = AverageMeter('Acc@1', ':6.2f')\n",
    "    progress = ProgressMeter(len(val_loader), batch_time, losses, top1)\n",
    "\n",
    "    # switch to evaluate mode\n",
    "    model.eval()\n",
    "\n",
    "    with torch.no_grad():\n",
    "        end = time.time()\n",
    "        for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):\n",
    "            input = input.cuda()\n",
    "            target = target.cuda()\n",
    "\n",
    "            # compute output\n",
    "            output = model(input)\n",
    "            loss = criterion(output, target)\n",
    "\n",
    "            # measure accuracy and record loss\n",
    "            acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100\n",
    "            losses.update(loss.item(), input.size(0))\n",
    "            top1.update(acc, input.size(0))\n",
    "            # measure elapsed time\n",
    "            batch_time.update(time.time() - end)\n",
    "            end = time.time()\n",
    "\n",
    "        # TODO: this should also be done with the ProgressMeter\n",
    "        print(' * Acc@1 {top1.avg:.3f}'\n",
    "              .format(top1=top1))\n",
    "        return top1\n",
    "\n",
    "def predict(test_loader, model, tta=10):\n",
    "    # switch to evaluate mode\n",
    "    model.eval()\n",
    "    \n",
    "    test_pred_tta = None\n",
    "    for _ in range(tta):\n",
    "        test_pred = []\n",
    "        with torch.no_grad():\n",
    "            end = time.time()\n",
    "            for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):\n",
    "                input = input.cuda()\n",
    "                target = target.cuda()\n",
    "\n",
    "                # compute output\n",
    "                output = model(input)\n",
    "                output = F.softmax(output, dim=1)\n",
    "                output = output.data.cpu().numpy()\n",
    "\n",
    "                test_pred.append(output)\n",
    "        test_pred = np.vstack(test_pred)\n",
    "    \n",
    "        if test_pred_tta is None:\n",
    "            test_pred_tta = test_pred\n",
    "        else:\n",
    "            test_pred_tta += test_pred\n",
    "    \n",
    "    return test_pred_tta\n",
    "\n",
    "def train(train_loader, model, criterion, optimizer, epoch):\n",
    "    batch_time = AverageMeter('Time', ':6.3f')\n",
    "    losses = AverageMeter('Loss', ':.4e')\n",
    "    top1 = AverageMeter('Acc@1', ':6.2f')\n",
    "    progress = ProgressMeter(len(train_loader), batch_time, losses, top1)\n",
    "\n",
    "    # switch to train mode\n",
    "    model.train()\n",
    "\n",
    "    end = time.time()\n",
    "    for i, (input, target) in enumerate(train_loader):\n",
    "        input = input.cuda(non_blocking=True)\n",
    "        target = target.cuda(non_blocking=True)\n",
    "\n",
    "        # compute output\n",
    "        output = model(input)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        # measure accuracy and record loss\n",
    "        losses.update(loss.item(), input.size(0))\n",
    "\n",
    "        acc = (output.argmax(1).view(-1) == target.float().view(-1)).float().mean() * 100\n",
    "        top1.update(acc, input.size(0))\n",
    "\n",
    "        # compute gradient and do SGD step\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # measure elapsed time\n",
    "        batch_time.update(time.time() - end)\n",
    "        end = time.time()\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            progress.pr2int(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "dde84a60-5939-4b69-b0a9-bb43773186b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "class XFDataset(Dataset):\n",
    "    def __init__(self, img_path, img_label, transform=None):\n",
    "        self.img_path = img_path\n",
    "        self.img_label = img_label\n",
    "        \n",
    "        if transform is not None:\n",
    "            self.transform = transform\n",
    "        else:\n",
    "            self.transform = None\n",
    "    \n",
    "    def __getitem__(self, index):\n",
    "        img = Image.open(self.img_path[index]).convert('RGB')\n",
    "        \n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "        \n",
    "        return img, torch.from_numpy(np.array(self.img_label[index]))\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "aad4267f-b007-48a4-8e06-cf89df9b9f2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'\n",
    "\n",
    "import timm\n",
    "model = timm.create_model('resnet18', pretrained=True, num_classes=5)\n",
    "model = model.cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "abc0c12c-5cdd-42fd-bfd6-9027431e46e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch:  0\n",
      "[   0/1091]\tTime  0.324 ( 0.324)\tLoss 4.5023e-03 (4.5023e-03)\tAcc@1 100.00 (100.00)\n",
      "[ 100/1091]\tTime  0.044 ( 0.046)\tLoss 7.8433e-02 (6.3462e-02)\tAcc@1  95.00 ( 97.67)\n",
      "[ 200/1091]\tTime  0.044 ( 0.044)\tLoss 6.1962e-03 (6.5124e-02)\tAcc@1 100.00 ( 97.84)\n",
      "[ 300/1091]\tTime  0.045 ( 0.044)\tLoss 1.7800e-01 (6.1890e-02)\tAcc@1  95.00 ( 98.02)\n",
      "[ 400/1091]\tTime  0.044 ( 0.044)\tLoss 2.4579e-03 (5.9559e-02)\tAcc@1 100.00 ( 98.18)\n",
      "[ 500/1091]\tTime  0.043 ( 0.044)\tLoss 2.0658e-03 (5.5116e-02)\tAcc@1 100.00 ( 98.31)\n",
      "[ 600/1091]\tTime  0.043 ( 0.044)\tLoss 2.6326e-03 (5.1129e-02)\tAcc@1 100.00 ( 98.40)\n",
      "[ 700/1091]\tTime  0.043 ( 0.043)\tLoss 5.6864e-03 (5.3521e-02)\tAcc@1 100.00 ( 98.35)\n",
      "[ 800/1091]\tTime  0.045 ( 0.043)\tLoss 1.4429e-01 (5.1473e-02)\tAcc@1  90.00 ( 98.37)\n",
      "[ 900/1091]\tTime  0.043 ( 0.043)\tLoss 6.3551e-03 (5.1053e-02)\tAcc@1 100.00 ( 98.38)\n",
      "[1000/1091]\tTime  0.042 ( 0.043)\tLoss 2.9711e-03 (4.9586e-02)\tAcc@1 100.00 ( 98.43)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2130898/2676678360.py:51: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  for i, (input, target) in tqdm_notebook(enumerate(val_loader), total=len(val_loader)):\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8eac7481bdd40c193d9e0c9e9110cc2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 99.800\n",
      "Epoch:  1\n",
      "[   0/1091]\tTime  0.244 ( 0.244)\tLoss 9.8599e-02 (9.8599e-02)\tAcc@1  95.00 ( 95.00)\n",
      "[ 100/1091]\tTime  0.043 ( 0.045)\tLoss 2.1680e-03 (3.1402e-02)\tAcc@1 100.00 ( 98.66)\n",
      "[ 200/1091]\tTime  0.045 ( 0.044)\tLoss 1.1284e-04 (2.5259e-02)\tAcc@1 100.00 ( 99.10)\n",
      "[ 300/1091]\tTime  0.048 ( 0.044)\tLoss 1.1205e-04 (2.6241e-02)\tAcc@1 100.00 ( 99.02)\n",
      "[ 400/1091]\tTime  0.043 ( 0.044)\tLoss 1.4951e-02 (3.1412e-02)\tAcc@1 100.00 ( 98.92)\n",
      "[ 500/1091]\tTime  0.044 ( 0.044)\tLoss 1.9477e-01 (3.7644e-02)\tAcc@1  95.00 ( 98.70)\n",
      "[ 600/1091]\tTime  0.040 ( 0.043)\tLoss 2.2300e-03 (3.7396e-02)\tAcc@1 100.00 ( 98.70)\n",
      "[ 700/1091]\tTime  0.039 ( 0.043)\tLoss 1.9081e-02 (3.5318e-02)\tAcc@1 100.00 ( 98.77)\n",
      "[ 800/1091]\tTime  0.043 ( 0.043)\tLoss 1.0084e-01 (3.4443e-02)\tAcc@1  95.00 ( 98.83)\n",
      "[ 900/1091]\tTime  0.042 ( 0.043)\tLoss 3.8712e-03 (3.3015e-02)\tAcc@1 100.00 ( 98.88)\n",
      "[1000/1091]\tTime  0.044 ( 0.043)\tLoss 3.4781e-04 (3.2336e-02)\tAcc@1 100.00 ( 98.92)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "80b2f87ed0e242c9b9bd3d2e9b940d75",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 100.000\n",
      "Epoch:  2\n",
      "[   0/1091]\tTime  0.259 ( 0.259)\tLoss 1.4724e-03 (1.4724e-03)\tAcc@1 100.00 (100.00)\n",
      "[ 100/1091]\tTime  0.045 ( 0.045)\tLoss 1.5199e-02 (1.5595e-02)\tAcc@1 100.00 ( 99.60)\n",
      "[ 200/1091]\tTime  0.042 ( 0.044)\tLoss 4.7581e-04 (1.6610e-02)\tAcc@1 100.00 ( 99.53)\n",
      "[ 300/1091]\tTime  0.044 ( 0.044)\tLoss 1.9028e-02 (1.9979e-02)\tAcc@1 100.00 ( 99.40)\n",
      "[ 400/1091]\tTime  0.043 ( 0.044)\tLoss 1.0502e-03 (2.2769e-02)\tAcc@1 100.00 ( 99.29)\n",
      "[ 500/1091]\tTime  0.043 ( 0.044)\tLoss 4.2360e-03 (2.0434e-02)\tAcc@1 100.00 ( 99.35)\n",
      "[ 600/1091]\tTime  0.040 ( 0.044)\tLoss 5.9249e-05 (1.9844e-02)\tAcc@1 100.00 ( 99.39)\n",
      "[ 700/1091]\tTime  0.047 ( 0.044)\tLoss 1.8846e-05 (1.7879e-02)\tAcc@1 100.00 ( 99.46)\n",
      "[ 800/1091]\tTime  0.043 ( 0.044)\tLoss 5.5814e-03 (1.8370e-02)\tAcc@1 100.00 ( 99.41)\n",
      "[ 900/1091]\tTime  0.047 ( 0.043)\tLoss 1.2090e-02 (1.9222e-02)\tAcc@1 100.00 ( 99.38)\n",
      "[1000/1091]\tTime  0.043 ( 0.043)\tLoss 1.3726e-03 (2.1988e-02)\tAcc@1 100.00 ( 99.28)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4b73c10272f442ba96906aa98fbccfc5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 99.000\n",
      "Epoch:  3\n",
      "[   0/1091]\tTime  0.256 ( 0.256)\tLoss 4.7306e-02 (4.7306e-02)\tAcc@1  95.00 ( 95.00)\n",
      "[ 100/1091]\tTime  0.041 ( 0.045)\tLoss 6.9081e-06 (1.3448e-02)\tAcc@1 100.00 ( 99.41)\n",
      "[ 200/1091]\tTime  0.043 ( 0.044)\tLoss 9.6763e-04 (1.4315e-02)\tAcc@1 100.00 ( 99.45)\n",
      "[ 300/1091]\tTime  0.036 ( 0.044)\tLoss 9.7373e-02 (1.1404e-02)\tAcc@1  95.00 ( 99.58)\n",
      "[ 400/1091]\tTime  0.045 ( 0.044)\tLoss 6.9967e-04 (1.2339e-02)\tAcc@1 100.00 ( 99.56)\n",
      "[ 500/1091]\tTime  0.042 ( 0.044)\tLoss 2.7265e-03 (1.2358e-02)\tAcc@1 100.00 ( 99.55)\n",
      "[ 600/1091]\tTime  0.042 ( 0.044)\tLoss 1.2706e-04 (1.2163e-02)\tAcc@1 100.00 ( 99.56)\n",
      "[ 700/1091]\tTime  0.043 ( 0.044)\tLoss 4.7863e-03 (1.2750e-02)\tAcc@1 100.00 ( 99.54)\n",
      "[ 800/1091]\tTime  0.044 ( 0.044)\tLoss 1.5516e-03 (1.3346e-02)\tAcc@1 100.00 ( 99.51)\n",
      "[ 900/1091]\tTime  0.043 ( 0.044)\tLoss 4.5972e-03 (1.4055e-02)\tAcc@1 100.00 ( 99.48)\n",
      "[1000/1091]\tTime  0.044 ( 0.044)\tLoss 1.3417e-02 (1.3932e-02)\tAcc@1 100.00 ( 99.49)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1c7347bda72745ba895cdb7de2cf8559",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 99.800\n",
      "Epoch:  4\n",
      "[   0/1091]\tTime  0.262 ( 0.262)\tLoss 1.2434e-03 (1.2434e-03)\tAcc@1 100.00 (100.00)\n",
      "[ 100/1091]\tTime  0.044 ( 0.045)\tLoss 1.0617e-04 (4.5660e-03)\tAcc@1 100.00 ( 99.85)\n",
      "[ 200/1091]\tTime  0.043 ( 0.044)\tLoss 3.6592e-04 (3.7412e-03)\tAcc@1 100.00 ( 99.93)\n",
      "[ 300/1091]\tTime  0.047 ( 0.044)\tLoss 5.0360e-05 (4.0069e-03)\tAcc@1 100.00 ( 99.93)\n",
      "[ 400/1091]\tTime  0.044 ( 0.044)\tLoss 6.4318e-05 (5.0542e-03)\tAcc@1 100.00 ( 99.86)\n",
      "[ 500/1091]\tTime  0.043 ( 0.044)\tLoss 9.8798e-04 (4.7829e-03)\tAcc@1 100.00 ( 99.85)\n",
      "[ 600/1091]\tTime  0.042 ( 0.044)\tLoss 3.8230e-05 (4.4126e-03)\tAcc@1 100.00 ( 99.86)\n",
      "[ 700/1091]\tTime  0.043 ( 0.044)\tLoss 2.6449e-04 (5.7047e-03)\tAcc@1 100.00 ( 99.84)\n",
      "[ 800/1091]\tTime  0.043 ( 0.044)\tLoss 3.7706e-02 (1.0854e-02)\tAcc@1 100.00 ( 99.66)\n",
      "[ 900/1091]\tTime  0.044 ( 0.044)\tLoss 2.2186e-02 (1.3616e-02)\tAcc@1 100.00 ( 99.55)\n",
      "[1000/1091]\tTime  0.043 ( 0.044)\tLoss 7.5228e-03 (1.4314e-02)\tAcc@1 100.00 ( 99.52)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ff22446f6171415fac1e1f033225b2d8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 100.000\n",
      "Epoch:  5\n",
      "[   0/1091]\tTime  0.273 ( 0.273)\tLoss 3.7738e-03 (3.7738e-03)\tAcc@1 100.00 (100.00)\n",
      "[ 100/1091]\tTime  0.043 ( 0.046)\tLoss 2.5157e-04 (5.4340e-03)\tAcc@1 100.00 ( 99.90)\n",
      "[ 200/1091]\tTime  0.042 ( 0.044)\tLoss 5.3610e-04 (9.9741e-03)\tAcc@1 100.00 ( 99.65)\n",
      "[ 300/1091]\tTime  0.043 ( 0.044)\tLoss 9.2436e-03 (8.5162e-03)\tAcc@1 100.00 ( 99.72)\n",
      "[ 400/1091]\tTime  0.044 ( 0.044)\tLoss 7.5036e-03 (1.0029e-02)\tAcc@1 100.00 ( 99.69)\n",
      "[ 500/1091]\tTime  0.043 ( 0.044)\tLoss 2.5220e-03 (1.0865e-02)\tAcc@1 100.00 ( 99.65)\n",
      "[ 600/1091]\tTime  0.044 ( 0.044)\tLoss 2.9781e-04 (1.1034e-02)\tAcc@1 100.00 ( 99.63)\n",
      "[ 700/1091]\tTime  0.044 ( 0.044)\tLoss 1.9654e-04 (1.0148e-02)\tAcc@1 100.00 ( 99.66)\n",
      "[ 800/1091]\tTime  0.042 ( 0.044)\tLoss 9.3135e-03 (1.0139e-02)\tAcc@1 100.00 ( 99.64)\n",
      "[ 900/1091]\tTime  0.041 ( 0.044)\tLoss 2.0706e-02 (1.1158e-02)\tAcc@1 100.00 ( 99.63)\n",
      "[1000/1091]\tTime  0.042 ( 0.044)\tLoss 1.7011e-02 (1.2312e-02)\tAcc@1 100.00 ( 99.60)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b7215a5f8988450ba99ff5cd9a242492",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 99.600\n",
      "Epoch:  6\n",
      "[   0/1091]\tTime  0.255 ( 0.255)\tLoss 4.8645e-03 (4.8645e-03)\tAcc@1 100.00 (100.00)\n",
      "[ 100/1091]\tTime  0.043 ( 0.045)\tLoss 7.8462e-02 (1.0925e-02)\tAcc@1  95.00 ( 99.60)\n",
      "[ 200/1091]\tTime  0.044 ( 0.044)\tLoss 4.8311e-03 (9.4189e-03)\tAcc@1 100.00 ( 99.65)\n",
      "[ 300/1091]\tTime  0.043 ( 0.044)\tLoss 6.6987e-03 (1.0577e-02)\tAcc@1 100.00 ( 99.62)\n",
      "[ 400/1091]\tTime  0.043 ( 0.044)\tLoss 5.6467e-05 (1.1687e-02)\tAcc@1 100.00 ( 99.59)\n",
      "[ 500/1091]\tTime  0.043 ( 0.044)\tLoss 5.1439e-04 (1.3080e-02)\tAcc@1 100.00 ( 99.57)\n",
      "[ 600/1091]\tTime  0.043 ( 0.044)\tLoss 2.9773e-05 (1.3160e-02)\tAcc@1 100.00 ( 99.58)\n",
      "[ 700/1091]\tTime  0.045 ( 0.044)\tLoss 1.9180e-04 (1.2855e-02)\tAcc@1 100.00 ( 99.59)\n",
      "[ 800/1091]\tTime  0.044 ( 0.044)\tLoss 8.7255e-05 (1.3562e-02)\tAcc@1 100.00 ( 99.56)\n",
      "[ 900/1091]\tTime  0.040 ( 0.044)\tLoss 5.5155e-03 (1.3529e-02)\tAcc@1 100.00 ( 99.55)\n",
      "[1000/1091]\tTime  0.042 ( 0.044)\tLoss 9.3305e-04 (1.4700e-02)\tAcc@1 100.00 ( 99.53)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bc1c03b27caa440192fc8ad87f14bed5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 99.200\n",
      "Epoch:  7\n",
      "[   0/1091]\tTime  0.257 ( 0.257)\tLoss 1.2140e-03 (1.2140e-03)\tAcc@1 100.00 (100.00)\n",
      "[ 100/1091]\tTime  0.043 ( 0.045)\tLoss 2.5809e-06 (4.7557e-03)\tAcc@1 100.00 ( 99.80)\n",
      "[ 200/1091]\tTime  0.043 ( 0.044)\tLoss 8.4121e-04 (5.5801e-03)\tAcc@1 100.00 ( 99.83)\n",
      "[ 300/1091]\tTime  0.044 ( 0.044)\tLoss 4.5780e-05 (5.1677e-03)\tAcc@1 100.00 ( 99.85)\n",
      "[ 400/1091]\tTime  0.044 ( 0.044)\tLoss 1.2512e-03 (4.5028e-03)\tAcc@1 100.00 ( 99.86)\n",
      "[ 500/1091]\tTime  0.044 ( 0.044)\tLoss 2.1696e-06 (4.3899e-03)\tAcc@1 100.00 ( 99.88)\n",
      "[ 600/1091]\tTime  0.044 ( 0.044)\tLoss 8.2659e-05 (4.1196e-03)\tAcc@1 100.00 ( 99.88)\n",
      "[ 700/1091]\tTime  0.046 ( 0.044)\tLoss 4.6491e-06 (4.5201e-03)\tAcc@1 100.00 ( 99.86)\n",
      "[ 800/1091]\tTime  0.044 ( 0.044)\tLoss 3.7789e-06 (5.3785e-03)\tAcc@1 100.00 ( 99.84)\n",
      "[ 900/1091]\tTime  0.044 ( 0.044)\tLoss 5.0313e-05 (6.8424e-03)\tAcc@1 100.00 ( 99.79)\n",
      "[1000/1091]\tTime  0.043 ( 0.044)\tLoss 2.4302e-04 (8.0052e-03)\tAcc@1 100.00 ( 99.75)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "49364aaba68d4c058e33fc56dc2f74aa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 99.800\n",
      "Epoch:  8\n",
      "[   0/1091]\tTime  0.253 ( 0.253)\tLoss 1.5648e-04 (1.5648e-04)\tAcc@1 100.00 (100.00)\n",
      "[ 100/1091]\tTime  0.044 ( 0.045)\tLoss 1.0886e-03 (5.5487e-03)\tAcc@1 100.00 ( 99.85)\n",
      "[ 200/1091]\tTime  0.043 ( 0.044)\tLoss 6.3645e-05 (3.8627e-03)\tAcc@1 100.00 ( 99.90)\n",
      "[ 300/1091]\tTime  0.044 ( 0.044)\tLoss 2.3239e-03 (3.7744e-03)\tAcc@1 100.00 ( 99.90)\n",
      "[ 400/1091]\tTime  0.044 ( 0.044)\tLoss 5.6316e-04 (3.2983e-03)\tAcc@1 100.00 ( 99.91)\n",
      "[ 500/1091]\tTime  0.042 ( 0.044)\tLoss 3.8400e-05 (3.6069e-03)\tAcc@1 100.00 ( 99.89)\n",
      "[ 600/1091]\tTime  0.043 ( 0.044)\tLoss 7.1493e-02 (3.8539e-03)\tAcc@1  95.00 ( 99.88)\n",
      "[ 700/1091]\tTime  0.044 ( 0.044)\tLoss 4.4943e-04 (3.4334e-03)\tAcc@1 100.00 ( 99.90)\n",
      "[ 800/1091]\tTime  0.044 ( 0.044)\tLoss 6.7409e-06 (3.0434e-03)\tAcc@1 100.00 ( 99.91)\n",
      "[ 900/1091]\tTime  0.042 ( 0.044)\tLoss 3.4373e-04 (3.4362e-03)\tAcc@1 100.00 ( 99.89)\n",
      "[1000/1091]\tTime  0.044 ( 0.044)\tLoss 2.3087e-04 (4.1418e-03)\tAcc@1 100.00 ( 99.88)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d7e6fdcfb2c24fa89572e7560b037cbd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 100.000\n",
      "Epoch:  9\n",
      "[   0/1091]\tTime  0.258 ( 0.258)\tLoss 2.6329e-04 (2.6329e-04)\tAcc@1 100.00 (100.00)\n",
      "[ 100/1091]\tTime  0.040 ( 0.045)\tLoss 9.9950e-05 (8.3957e-03)\tAcc@1 100.00 ( 99.75)\n",
      "[ 200/1091]\tTime  0.039 ( 0.045)\tLoss 1.0490e-06 (7.6670e-03)\tAcc@1 100.00 ( 99.78)\n",
      "[ 300/1091]\tTime  0.042 ( 0.044)\tLoss 1.0004e-03 (1.2978e-02)\tAcc@1 100.00 ( 99.63)\n",
      "[ 400/1091]\tTime  0.043 ( 0.044)\tLoss 6.5297e-04 (1.0959e-02)\tAcc@1 100.00 ( 99.69)\n",
      "[ 500/1091]\tTime  0.043 ( 0.044)\tLoss 1.5865e-05 (9.2654e-03)\tAcc@1 100.00 ( 99.73)\n",
      "[ 600/1091]\tTime  0.044 ( 0.044)\tLoss 1.3202e-04 (8.4265e-03)\tAcc@1 100.00 ( 99.75)\n",
      "[ 700/1091]\tTime  0.044 ( 0.044)\tLoss 2.5028e-03 (7.3998e-03)\tAcc@1 100.00 ( 99.78)\n",
      "[ 800/1091]\tTime  0.043 ( 0.044)\tLoss 5.5669e-06 (7.8045e-03)\tAcc@1 100.00 ( 99.78)\n",
      "[ 900/1091]\tTime  0.044 ( 0.044)\tLoss 7.8614e-05 (7.3023e-03)\tAcc@1 100.00 ( 99.79)\n",
      "[1000/1091]\tTime  0.044 ( 0.044)\tLoss 7.3809e-05 (6.9125e-03)\tAcc@1 100.00 ( 99.80)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3a6b9252294f4f2fb3275bbe22c115da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/25 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " * Acc@1 99.600\n"
     ]
    }
   ],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    XFDataset(train_path[:-500], train_label[:-500], \n",
    "            transforms.Compose([\n",
    "                        transforms.Resize((256, 256)),\n",
    "                        transforms.RandomHorizontalFlip(),\n",
    "                        transforms.RandomVerticalFlip(),\n",
    "                        transforms.ColorJitter(brightness=.5, hue=.3),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ), batch_size=20, shuffle=True, num_workers=4, pin_memory=True\n",
    ")\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    XFDataset(train_path[-500:], train_label[-500:], \n",
    "            transforms.Compose([\n",
    "                        transforms.Resize((256, 256)),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ), batch_size=20, shuffle=False, num_workers=4, pin_memory=True\n",
    ")\n",
    "\n",
    "criterion = nn.CrossEntropyLoss().cuda()\n",
    "optimizer = torch.optim.Adam(model.parameters(), 0.005)\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)\n",
    "best_acc = 0.0\n",
    "for epoch in range(10):\n",
    "    scheduler.step()\n",
    "    print('Epoch: ', epoch)\n",
    "\n",
    "    train(train_loader, model, criterion, optimizer, epoch)\n",
    "    val_acc = validate(val_loader, model, criterion)\n",
    "    \n",
    "    if val_acc.avg.item() > best_acc:\n",
    "        best_acc = round(val_acc.avg.item(), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "41e86d06-e364-4e3f-8631-c2f3c0722e35",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2130898/2676678360.py:81: TqdmDeprecationWarning: This function will be removed in tqdm==5.0.0\n",
      "Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`\n",
      "  for i, (input, target) in tqdm_notebook(enumerate(test_loader), total=len(test_loader)):\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c4894256376449e7b7e13cd81fbd39ae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/112 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "test_loader = torch.utils.data.DataLoader(\n",
    "    XFDataset(test_path, [0] * len(test_path), \n",
    "            transforms.Compose([\n",
    "                        transforms.Resize((256, 256)),\n",
    "                        transforms.ToTensor(),\n",
    "                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "        ])\n",
    "    ), batch_size=40, shuffle=False, num_workers=4, pin_memory=True\n",
    ")\n",
    "\n",
    "val_label = pd.DataFrame()\n",
    "val_label['img'] = [x.split('/')[-1] for x in test_path]\n",
    "val_label['label'] = predict(test_loader, model, 1).argmax(1)\n",
    "val_label.to_csv('submit.csv', index=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "5dc9d50b-9959-4cd7-acfb-91644ccddac6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>image</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>001.jpg</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>002.jpg</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>003.jpg</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>004.jpg</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>005.jpg</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4458</th>\n",
       "      <td>995.jpg</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4459</th>\n",
       "      <td>996.jpg</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4460</th>\n",
       "      <td>997.jpg</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4461</th>\n",
       "      <td>998.jpg</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4462</th>\n",
       "      <td>999.jpg</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>4463 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        image  label\n",
       "0     001.jpg      0\n",
       "1     002.jpg      1\n",
       "2     003.jpg      3\n",
       "3     004.jpg      4\n",
       "4     005.jpg      0\n",
       "...       ...    ...\n",
       "4458  995.jpg      3\n",
       "4459  996.jpg      1\n",
       "4460  997.jpg      0\n",
       "4461  998.jpg      0\n",
       "4462  999.jpg      1\n",
       "\n",
       "[4463 rows x 2 columns]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "val_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f350b9a1-82a7-4051-8492-3dc762d8aaab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11",
   "language": "python",
   "name": "py3.11"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
